##
# Copyright (c) 2008-2017 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##

from twisted.python.failure import Failure

from twisted.internet.defer import Deferred, fail
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.protocols.memcache import MemCacheProtocol, NoSuchCommand

from twext.python.log import Logger
from twext.internet.gaiendpoint import GAIEndpoint
from twext.internet.adaptendpoint import connect
from twisted.internet.endpoints import UNIXClientEndpoint


class PooledMemCacheProtocol(MemCacheProtocol):
    """
    A L{MemCacheProtocol} that, once connected, hands itself to whoever is
    waiting on its factory's initial-connection L{Deferred}.

    @ivar factory: A L{MemCacheClientFactory} instance.
    """
    factory = None

    def connectionMade(self):
        """
        Run the base connection setup, then fire the factory's pending
        initial-connection L{Deferred} (if any) with this protocol and clear
        it so it is only ever fired once.
        """
        MemCacheProtocol.connectionMade(self)

        waiting = self.factory.deferred
        if waiting is None:
            # Already fired by an earlier connection (e.g. a reconnect).
            return
        waiting.callback(self)
        self.factory.deferred = None


class MemCacheClientFactory(ReconnectingClientFactory):
    """
    A client factory for MemCache that reconnects and notifies a pool of its
    state.

    @ivar connectionPool: A managing connection pool that we notify of events.
    @ivar deferred: A L{Deferred} that represents the initial connection.
    @ivar _protocolInstance: The current instance of our protocol that we pass
        to our connectionPool.
    """
    log = Logger()

    protocol = PooledMemCacheProtocol
    connectionPool = None
    _protocolInstance = None

    def __init__(self):
        self.deferred = Deferred()

    def _shutdownRequested(self):
        """
        Return C{True} when our connection pool has been told that the reactor
        is shutting down, in which case no reconnection should be attempted.
        A factory with no pool attached is never considered shutting down.
        """
        return (
            self.connectionPool is not None and
            self.connectionPool.shutdown_requested
        )

    def clientConnectionLost(self, connector, reason):
        """
        Notify the connectionPool that we've lost our connection, then let
        L{ReconnectingClientFactory} schedule a reconnect.
        """
        if self._shutdownRequested():
            # The reactor is stopping; don't reconnect
            return

        self.log.error("MemCache connection lost: {r}", r=reason)
        if self._protocolInstance is not None:
            # Mark the dead protocol busy so the pool stops handing it out.
            self.connectionPool.clientBusy(self._protocolInstance)

        ReconnectingClientFactory.clientConnectionLost(
            self,
            connector,
            reason)

    def clientConnectionFailed(self, connector, reason):
        """
        Notify the connectionPool that we're unable to connect, then let
        L{ReconnectingClientFactory} schedule a retry.
        """
        if self._shutdownRequested():
            # Same as clientConnectionLost: never schedule retries while the
            # reactor is stopping.
            return

        self.log.error("MemCache connection failed: {r}", r=reason)
        if self._protocolInstance is not None:
            self.connectionPool.clientBusy(self._protocolInstance)

        ReconnectingClientFactory.clientConnectionFailed(
            self,
            connector,
            reason)

    def buildProtocol(self, addr):
        """
        Build a new protocol attached to this factory (and, through it, to
        C{self.connectionPool}), first removing any previous protocol instance
        from the pool.
        """
        if self._protocolInstance is not None:
            self.connectionPool.clientGone(self._protocolInstance)

        self._protocolInstance = self.protocol()
        self._protocolInstance.factory = self
        return self._protocolInstance


class MemCachePool(object):
    """
    A connection pool for MemCacheProtocol instances.

    @ivar clientFactory: The L{ClientFactory} implementation that will be used
        for each protocol.

    @ivar _maxClients: A C{int} indicating the maximum number of clients.

    @ivar _endpoint: An L{IStreamClientEndpoint} provider indicating the server
        to connect to.

    @ivar _reactor: The L{IReactorTCP} provider used to initiate new
        connections.

    @ivar _busyClients: A C{set} that contains all currently busy clients.

    @ivar _freeClients: A C{set} that contains all currently free clients.

    @ivar _pendingConnects: A C{int} indicating how many connections are in
        progress.

    @ivar _commands: A C{list} of C{(Deferred, command, args, kwargs)} tuples
        queued because every allowed client was busy or connecting.

    @ivar shutdown_requested: C{True} once the reactor's 'before shutdown'
        trigger has run.

    @ivar shutdown_deferred: A L{Deferred} returned to the reactor's shutdown
        trigger when shutdown began while the pool was still busy; it fires
        once the pool becomes idle.  C{None} otherwise.
    """
    log = Logger()

    clientFactory = MemCacheClientFactory

    REQUEST_LOGGING_SIZE = 1024

    def __init__(self, endpoint, maxClients=5, reactor=None):
        """
        @param endpoint: An L{IStreamClientEndpoint} indicating the server to
            connect to.

        @param maxClients: A C{int} indicating the maximum number of clients.

        @param reactor: An L{IReactorTCP} provider used to initiate new
            connections.
        """
        self._endpoint = endpoint
        self._maxClients = maxClients

        if reactor is None:
            from twisted.internet import reactor
        self._reactor = reactor

        self.shutdown_deferred = None
        self.shutdown_requested = False
        reactor.addSystemEventTrigger(
            'before', 'shutdown', self._shutdownCallback
        )

        self._busyClients = set()
        self._freeClients = set()
        self._pendingConnects = 0
        self._commands = []

    def _isIdle(self):
        """
        Return C{True} when no client is busy, no command is queued and no
        connection attempt is outstanding.
        """
        return (
            len(self._busyClients) == 0 and
            len(self._commands) == 0 and
            self._pendingConnects == 0
        )

    def _shutdownCallback(self):
        """
        'before shutdown' trigger: record that shutdown was requested and, if
        work is still outstanding, return a L{Deferred} that makes the reactor
        wait until L{clientFree} observes the pool going idle.
        """
        self.shutdown_requested = True
        if self._isIdle():
            return None
        self.shutdown_deferred = Deferred()
        return self.shutdown_deferred

    def _newClientConnection(self):
        """
        Create a new client connection.

        @return: A L{Deferred} that fires with the L{IProtocol} instance.
        """
        self.log.debug(
            "Initiating new client connection to: {r!r}", r=self._endpoint
        )
        self._logClientStats()

        self._pendingConnects += 1

        def _connected(client):
            self._pendingConnects -= 1

            return client

        factory = self.clientFactory()
        factory.noisy = False

        factory.connectionPool = self

        connect(self._endpoint, factory)
        d = factory.deferred

        d.addCallback(_connected)
        return d

    def _formatRequestArgs(self, args):
        """
        Render positional request arguments as one space-separated string for
        logging, truncated to C{REQUEST_LOGGING_SIZE} characters.

        Each argument goes through C{%s} formatting so that non-C{str} values
        (C{bytes} keys, C{int} flags or expire times) are rendered instead of
        making C{str.join} raise C{TypeError}.
        """
        text = " ".join("%s" % (arg,) for arg in args)
        return text[:self.REQUEST_LOGGING_SIZE]

    def _performRequestOnClient(self, client, command, *args, **kwargs):
        """
        Perform the given request on the given client.

        @param client: A L{PooledMemCacheProtocol} that will be used to perform
            the given request.

        @param command: A C{str} representing an attribute of
            L{MemCacheProtocol}.
        @param args: Any positional arguments that should be passed to
            C{command}.
        @param kwargs: Any keyword arguments that should be passed to
            C{command}.

        @return: A L{Deferred} that fires with the result of the given command.
            On failure the error is logged, swallowed, and the L{Deferred}
            fires with C{None}.
        """
        def _freeClientAfterRequest(result):
            self.clientFree(client)
            return result

        def _reportError(failure):
            """
            Upon memcache error, log the failed request along with the error
            message and free the client.
            """
            try:
                self.log.error(
                    "Memcache error: {ex}; request: {cmd} {args}",
                    ex=failure.value,
                    cmd=command,
                    args=self._formatRequestArgs(args),
                )
            finally:
                # Always return the client to the pool, even if logging
                # fails; otherwise it would stay in _busyClients forever.
                self.clientFree(client)

        self.clientBusy(client)
        method = getattr(client, command, None)
        if method is not None:
            d = method(*args, **kwargs)
        else:
            d = fail(Failure(NoSuchCommand()))

        d.addCallbacks(_freeClientAfterRequest, _reportError)

        return d

    def performRequest(self, command, *args, **kwargs):
        """
        Select an available client and perform the given request on it.

        A free client is used if one exists. Otherwise a new connection is
        opened if C{_maxClients} allows it. If not, the request is queued
        until L{clientFree} is called.

        @param command: A C{str} representing an attribute of
            L{MemCacheProtocol}.
        @param args: Any positional arguments that should be passed to
            C{command}.
        @param kwargs: Any keyword arguments that should be passed to
            C{command}.

        @return: A L{Deferred} that fires with the result of the given command.
        """

        if self._freeClients:
            client = self._freeClients.pop()

            d = self._performRequestOnClient(
                client, command, *args, **kwargs)

        elif (
            len(self._busyClients) + self._pendingConnects >= self._maxClients
        ):
            d = Deferred()
            self._commands.append((d, command, args, kwargs))
            self.log.debug(
                "Command queued: {c}, {a!r}, {k!r}", c=command, a=args, k=kwargs
            )
            self._logClientStats()

        else:
            d = self._newClientConnection()
            d.addCallback(self._performRequestOnClient,
                          command, *args, **kwargs)

        return d

    def _logClientStats(self):
        """
        Emit a debug line with the free/busy/pending/queued counters.
        """
        self.log.debug(
            "Clients #free: {f}, #busy: {b}, #pending: {p}, #queued: {q}",
            f=len(self._freeClients),
            b=len(self._busyClients),
            p=self._pendingConnects,
            q=len(self._commands),
        )

    def clientGone(self, client):
        """
        Notify that the given client is to be removed from the pool completely.

        @param client: An instance of L{PooledMemCacheProtocol}.
        """
        if client in self._busyClients:
            self._busyClients.remove(client)

        elif client in self._freeClients:
            self._freeClients.remove(client)

        self.log.debug("Removed client: {c!r}", c=client)
        self._logClientStats()

    def clientBusy(self, client):
        """
        Notify that the given client is being used to complete a request.

        @param client: An instance of C{self.clientFactory}
        """

        self._freeClients.discard(client)
        self._busyClients.add(client)

        self.log.debug("Busied client: {c!r}", c=client)
        self._logClientStats()

    def clientFree(self, client):
        """
        Notify that the given client is free to handle more requests.

        Fires the pending shutdown L{Deferred} if the pool has become idle. If
        a command is queued, dispatches the oldest one through
        L{performRequest}.

        @param client: An instance of C{self.clientFactory}
        """
        self._busyClients.discard(client)
        self._freeClients.add(client)

        if self.shutdown_deferred is not None and self._isIdle():
            # Clear before firing: a later idle transition must not call
            # callback() on an already-fired Deferred (AlreadyCalledError).
            shutdownDeferred = self.shutdown_deferred
            self.shutdown_deferred = None
            shutdownDeferred.callback(None)

        if self._commands:
            d, command, args, kwargs = self._commands.pop(0)

            self.log.debug(
                "Performing Queued Command: {c}, {a}, {k}",
                c=command, a=args, k=kwargs,
            )
            self._logClientStats()

            # Forward both success and failure to the caller's Deferred.
            self.performRequest(command, *args, **kwargs).chainDeferred(d)

        self.log.debug("Freed client: {c!r}", c=client)
        self._logClientStats()

    def suggestMaxClients(self, maxClients):
        """
        Suggest the maximum number of concurrently connected clients.

        @param maxClients: A C{int} indicating how many client connections we
            should keep open.
        """
        self._maxClients = maxClients

    def get(self, *args, **kwargs):
        return self.performRequest('get', *args, **kwargs)

    def set(self, *args, **kwargs):
        return self.performRequest('set', *args, **kwargs)

    def checkAndSet(self, *args, **kwargs):
        return self.performRequest('checkAndSet', *args, **kwargs)

    def delete(self, *args, **kwargs):
        return self.performRequest('delete', *args, **kwargs)

    def add(self, *args, **kwargs):
        return self.performRequest('add', *args, **kwargs)

    def incr(self, *args, **kwargs):
        return self.performRequest('increment', *args, **kwargs)

    def decr(self, *args, **kwargs):
        return self.performRequest('decrement', *args, **kwargs)

    def flushAll(self, *args, **kwargs):
        return self.performRequest('flushAll', *args, **kwargs)


class CachePoolUserMixIn(object):
    """
    A mixin giving access to a memcache pool: the explicitly assigned one if
    set, otherwise the pool registered for C{_cachePoolHandle}.

    @ivar _cachePool: An explicitly assigned pool, or C{None} to fall back to
        L{defaultCachePool}.
    @ivar _cachePoolHandle: The handle passed to L{defaultCachePool} when no
        pool has been assigned.
    """
    _cachePool = None
    _cachePoolHandle = "Default"

    def getCachePool(self):
        assigned = self._cachePool
        if assigned is not None:
            return assigned
        return defaultCachePool(self._cachePoolHandle)


_memCachePools = {}         # Maps a name to a pool object
_memCachePoolHandler = {}   # Maps a handler id to a named pool


def installPools(pools, maxClients=5, reactor=None):
    """
    Create and register a memcache pool for every enabled entry in C{pools}.

    @param pools: A mapping of pool name to a config mapping.  Each config is
        read for C{"ClientEnabled"}, C{"HandleCacheTypes"}, and either a
        non-empty C{"MemcacheSocket"} (UNIX socket path) or C{"BindAddress"}
        and C{"Port"}.
    @param maxClients: Maximum connections per pool.
    @param reactor: Reactor used for endpoints; the global reactor if C{None}.
    """
    if reactor is None:
        from twisted.internet import reactor
    for poolName, config in pools.items():
        if not config["ClientEnabled"]:
            continue
        socketPath = config.get("MemcacheSocket")
        if socketPath:
            endpoint = UNIXClientEndpoint(reactor, socketPath)
        else:
            endpoint = GAIEndpoint(
                reactor, config["BindAddress"], config["Port"]
            )
        _installPool(
            poolName, config["HandleCacheTypes"], endpoint, maxClients, reactor
        )


def _installPool(
    name, handleTypes, serverEndpoint, maxClients=5, reactor=None
):
    """
    Create a L{MemCachePool} for C{serverEndpoint} and register it under
    C{name} and under every handle in C{handleTypes}.

    @param name: Key under which the pool is stored in C{_memCachePools}.
    @param handleTypes: Iterable of handles mapped to this pool in
        C{_memCachePoolHandler}.
    @param serverEndpoint: Endpoint the pool connects to.
    @param maxClients: Maximum number of concurrent clients.
    @param reactor: Reactor handed to the pool (C{None} selects the global
        reactor inside L{MemCachePool}).
    """
    # Pass the caller's reactor through instead of discarding it.
    pool = MemCachePool(serverEndpoint, maxClients=maxClients, reactor=reactor)
    _memCachePools[name] = pool

    for handle in handleTypes:
        _memCachePoolHandler[handle] = pool


def defaultCachePool(name):
    """
    Return the pool registered for handle C{name}, falling back to the pool
    registered for C{"Default"} when C{name} is unknown.

    Raises C{KeyError} if neither handle is registered.
    """
    handle = name if name in _memCachePoolHandler else "Default"
    return _memCachePoolHandler[handle]
